#!/usr/bin/env python

import numpy
import os
import sys
import optparse

def load_binary_file(fname, dtype='float32', modify=False):
    """Map a raw binary sample file into memory as a flat array.

    Parameters:
        fname:  path of the file to map.
        dtype:  numpy dtype of the samples stored in the file.
        modify: if true, map the file read/write ('r+') so changes are
                written back to the file; otherwise map it copy-on-write
                ('c') so changes stay in memory only.

    Returns a numpy.memmap covering the whole file, or an empty ndarray
    of the requested dtype when the file is empty.
    """
    # numpy.memmap refuses to map an empty file, so handle it up front.
    # An empty ndarray (instead of a plain list) keeps the return value
    # uniform: callers can slice it and call .tofile() on it like on any
    # other result.
    if os.stat(fname).st_size == 0:
        return numpy.empty(0, dtype=dtype)

    return numpy.memmap(fname, dtype=dtype,
                        mode='r+' if modify else 'c')

# Short alias that keeps the option table below compact.
o = optparse.make_option

# Command line options understood by this script.
opts = [
    o('-c', '--cut-off', dest='cutoff', action='store', type='int',
      help='max number of samples to use'),
    o('--count', dest='count', action='store_true',
      help='just report the number of samples in each file'),
    o('--only-min', dest='only_min', action='store_true',
      help='When counting, report only the minimum number of samples.'),
    o('--output-dir', dest='output_dir', action='store',
      help='directory where output files should be stored.'),
]

# Values for options that are not given on the command line.
defaults = dict(
    cutoff=None,
    count=False,
    only_min=False,
    output_dir=None,
)

# Parsed command line options; assigned when the file is run as a script.
options = None

def load_files(fnames):
    """Load every file in fnames and return the results as a list.

    Each file is loaded with load_binary_file() using its default
    arguments; the result list has the same order as fnames.
    """
    loaded = []
    for path in fnames:
        loaded.append(load_binary_file(path))
    return loaded

def shuffle_truncate(arrays, fnames, target_length=None):
    """Shuffle sample arrays and cut them down to a common length.

    Parameters:
        arrays:        sequence of sample arrays.  Arrays longer than
                       the selected length are shuffled IN PLACE.
        fnames:        names matching `arrays`; only used in the
                       progress messages.
        target_length: number of samples to keep from each array.  If
                       not given (or zero), the length of the shortest
                       array is used.

    Returns a list with the first `target_length` (or shortest-length)
    elements of each array.  Arrays that are shorter than the selected
    length are returned in full.  An empty `arrays` yields an empty list.
    """
    # Nothing to do; also avoids calling min() on an empty sequence.
    if not arrays:
        return []

    # Determine how many samples we can use.
    if target_length:
        shortest = target_length
    else:
        shortest = min(len(a) for a in arrays)
        print("Selecting %d samples from each data file." % shortest)

    # Make sure we'll select samples from all
    # parts of the data file.
    for a, n in zip(arrays, fnames):
        if len(a) > shortest:
            # Gotta be uniformly shuffled.
            print("Shuffling %s ..." % n)
            numpy.random.shuffle(a)
        else:
            # Not enough samples: every sample is kept anyway, so
            # shuffling would not change which ones are selected.
            print("Not shuffling %s." % n)

    # Now select the same number of samples from each file.
    return [a[:shortest] for a in arrays]

def store_files(arrays, fnames):
    """Write each array as raw binary data to the matching file name.

    Parameters:
        arrays: sequence of arrays (or array-likes) to store.
        fnames: output paths, paired with `arrays` by position.

    Existing files are overwritten.
    """
    for a, fn in zip(arrays, fnames):
        print('Storing %s.' % fn)
        # The with-block closes the file even if writing fails.
        with open(fn, 'wb') as fd:
            # asanyarray passes ndarrays (including memmaps) through
            # unchanged and also accepts plain sequences such as the
            # empty list, which has no tofile() method of its own.
            numpy.asanyarray(a).tofile(fd)

def target_file(fname, want_ext):
    """Return the output path that corresponds to the input file fname.

    The output keeps the base name of fname.  If want_ext is not None,
    the extension of the base name is replaced by want_ext.  The file is
    placed in options.output_dir if that is set, and next to the input
    file otherwise.
    """
    base = os.path.basename(fname)
    out_dir = options.output_dir or os.path.dirname(fname)
    if want_ext is not None:
        stem = os.path.splitext(base)[0]
        base = "%s.%s" % (stem, want_ext)
    return os.path.join(out_dir, base)

def shuffle_truncate_store(files, cutoff=None, ext='sf32'):
    """Load, shuffle/truncate and store the given files as one group.

    All files are loaded and passed to shuffle_truncate() together with
    cutoff as the target length.  Each result is stored under the name
    target_file() derives from the input name and ext.
    """
    shuffled = shuffle_truncate(load_files(files), files,
                                target_length=cutoff)
    destinations = []
    for path in files:
        destinations.append(target_file(path, ext))
    store_files(shuffled, destinations)

def shuffle_truncate_store_individually(files, cutoff, ext='sf32'):
    """Shuffle, truncate and store each file on its own.

    Parameters:
        files:  input file names, processed one at a time.
        cutoff: number of samples to keep from each file.
        ext:    extension of the output files.  The same extension is
                used for the "already exists" check and for storing, so
                outputs of an earlier run are actually detected.

    A progress line is printed per file.  Files whose output already
    exists and files that are empty are skipped.
    """
    total = len(files)
    # Zero-pad the running index to the width of the total count.
    fmt = "[%%0%dd/%%d] %%s" % len(str(total))
    for i, f in enumerate(files):
        print(fmt % (i + 1, total, os.path.basename(f)))
        sys.stdout.flush()
        name = target_file(f, ext)
        size = os.stat(f).st_size
        if os.path.exists(name):
            print("Skipping since %s exists." % name)
        elif size == 0:
            print("Skipping since trace is empty.")
        else:
            shuffle_truncate_store([f], cutoff=cutoff, ext=ext)

def report_sample_counts(files):
    """Print the number of samples found in each of the given files.

    If options.only_min is set, only the smallest sample count is
    printed (and nothing at all when `files` is empty).  Otherwise one
    line with running index, sample count and file name is printed per
    file.
    """
    counts = []
    total = len(files)
    # Zero-pad the running index to the width of the total count.
    fmt = "[%%0%dd/%%d] %%8d %%s" % len(str(total))
    for i, f in enumerate(files):
        d = load_binary_file(f)
        counts.append(len(d))
        if not options.only_min:
            print(fmt % (i + 1, total, len(d), f))
            sys.stdout.flush()
        # Drop our reference to the loaded data before the next file.
        del d
    # min() raises on an empty sequence, so only report if we saw files.
    if options.only_min and counts:
        print(min(counts))

if __name__ == '__main__':
    # FIXME: would be nicer with argparse
    parser = optparse.OptionParser(option_list=opts)
    parser.set_defaults(**defaults)
    # This rebinds the module-level `options` read by the functions above.
    (options, files) = parser.parse_args()

    if not files:
        print("Usage: ft-shuffle-truncate data1.float32 data2.float32 data3.float32 ...")
    elif options.count:
        # Only report how many samples each file holds.
        report_sample_counts(files)
    elif options.cutoff:
        # Fixed sample count: every file can be processed on its own.
        shuffle_truncate_store_individually(files, options.cutoff)
    else:
        # No cut-off: the common length depends on all files together.
        shuffle_truncate_store(files)
